#!/usr/bin/env python
import os
import sys
if os.path.exists('/home/chieh/code/wuML'):
sys.path.insert(0,'/home/chieh/code/wuML')
import wuml
import numpy as np
Regression test
The data is designed such that x₁ and x₂ have a minor positive impact while x₄ has a major negative impact:
y = 5x₁ + x₂ + x₁x₂ − 8x₄ − 2x₄² + δ
# Regression test: load the Gaussian SHAP example data, fit a linear
# regressor, and explain its predictions with the SHAP algorithm.
# Fix: label_type was misspelled 'continuout' — the regression label is
# continuous (cf. the classification example, which uses 'discrete').
data = wuml.wData('../../data/shap_regress_example_gaussian.csv', first_row_is_label=True,
	label_type='continuous', label_column_name='label', preprocess_data='center and scale')
model = wuml.regression(data, regressor='linear')
E = wuml.explainer(data, model, explainer_algorithm='shap')
exp = E(data)	# per-sample SHAP attributions for every feature
|   | A | B | C | D | y | ŷ | Δy |
|---|---|---|---|---|---|---|---|
| 0 | 8.331189 | 0.289953 | -0.501698 | -18.577235 | -18.0403 | -13.174735 | 4.865565 |
| 1 | 8.868911 | -1.330877 | -0.487391 | 3.638149 | 7.6894 | 7.971847 | 0.282447 |
| 2 | -1.370143 | 0.302190 | -0.084151 | -11.272708 | -16.0089 | -15.141755 | 0.867145 |
| 3 | 3.119704 | -0.037743 | -0.234172 | -0.866588 | 1.1407 | -0.735743 | 1.876443 |
| 4 | 6.928437 | -0.422388 | -0.168741 | 10.163563 | 12.3256 | 13.783926 | 1.458326 |
| 5 | -14.097783 | 0.588115 | -0.444521 | 9.124437 | -8.9458 | -7.546695 | 1.399105 |
| 6 | 10.958492 | -1.892255 | -0.035028 | 3.970595 | 8.0164 | 10.284860 | 2.268460 |
| 7 | 7.129498 | 1.548023 | -0.089604 | -1.279825 | 8.0926 | 4.591148 | 3.501452 |
| 8 | -5.446437 | -2.511641 | 0.161915 | 0.780784 | -5.9542 | -9.732324 | 3.778124 |
| 9 | 5.557895 | 1.233859 | 0.181624 | 5.039437 | 11.0728 | 9.295871 | 1.776929 |
| 10 | -6.281854 | -1.851778 | 0.841435 | -15.883307 | -28.3988 | -25.892449 | 2.506351 |
| 11 | -3.482064 | -0.696429 | 0.614578 | -4.987805 | -10.1866 | -11.268665 | 1.082065 |
| 12 | -9.218802 | -0.431213 | 0.435844 | -1.360614 | -11.3402 | -13.291730 | 1.951530 |
| 13 | -3.487779 | -1.570089 | 0.001989 | -1.745063 | -6.9249 | -9.517886 | 2.592986 |
| 14 | -0.488487 | 0.174995 | 0.305182 | 5.600323 | 3.2876 | 2.875069 | 0.412531 |
| 15 | -4.327872 | -0.604063 | 0.394624 | 18.262990 | 4.3769 | 11.008735 | 6.631835 |
| 16 | 0.087681 | -0.653717 | 0.803367 | -2.065437 | -3.7108 | -4.545051 | 0.834251 |
| 17 | -5.547747 | -0.119873 | -0.376839 | 1.034297 | -5.5989 | -7.727106 | 2.128206 |
| 18 | 5.085636 | -1.633863 | -0.213362 | 8.591409 | 7.5997 | 9.112875 | 1.513175 |
| 19 | -5.358116 | -0.861983 | 0.143756 | 1.710333 | -4.8953 | -7.082954 | 2.187654 |
| 20 | -6.887116 | 0.878982 | -0.245077 | 16.497683 | 1.5804 | 7.527528 | 5.947128 |
| 21 | 6.898303 | 2.049863 | -0.601795 | 3.902805 | 13.5378 | 9.532232 | 4.005568 |
| 22 | -6.397192 | 1.059833 | 0.189578 | -9.119237 | -18.1950 | -16.983963 | 1.211037 |
| 23 | 0.248218 | 0.968172 | -0.190401 | -4.329414 | -4.4233 | -6.020369 | 1.597069 |
| 24 | -0.779428 | 1.920432 | -0.075597 | -1.500836 | -1.6582 | -3.152374 | 1.494174 |
| 25 | 8.949959 | -1.766824 | 0.623432 | -6.769828 | -4.0952 | -1.680205 | 2.414995 |
| 26 | -6.928679 | 2.105989 | 0.194780 | 9.173654 | -1.3490 | 1.828800 | 3.177800 |
| 27 | 9.156215 | 1.561084 | -0.946357 | -6.181082 | 5.0631 | 0.872915 | 4.190185 |
| 28 | -5.308240 | 2.066571 | 0.121946 | -5.219960 | -11.7617 | -11.056628 | 0.705072 |
| 29 | 4.087603 | -0.363321 | -0.319312 | -6.331519 | -4.6488 | -5.643493 | 0.994693 |
| Most chosen | |
|---|---|
| A | 15 |
| D | 14 |
| B | 1 |
| C | 0 |
| Most weighted | |
|---|---|
| D | 194.980919 |
| A | 170.815480 |
| B | 33.496118 |
| C | 10.028097 |
Classification test
# Classification test: fit a logistic-regression classifier on the SHAP
# example data and explain its predictions with the SHAP algorithm.
csv_path = '../../data/shap_classifier_example.csv'
data = wuml.wData(csv_path, label_column_name='label',
	label_type='discrete', first_row_is_label=True)
model = wuml.classification(data, classifier='LogisticRegression')
E = wuml.explainer(data, model, explainer_algorithm='shap')
exp = E(data)	# per-sample SHAP attributions for every feature
|   | A | B | C | D | y | ŷ | Δy |
|---|---|---|---|---|---|---|---|
| 0 | -0.005556 | -0.127778 | 0.016667 | -0.350000 | 0.0 | 0.0 | 0.0 |
| 1 | -0.008333 | -0.019444 | 0.002778 | -0.441667 | 0.0 | 0.0 | 0.0 |
| 2 | 0.333333 | 0.155556 | 0.038889 | 0.005556 | 1.0 | 1.0 | 0.0 |
| 3 | 0.072222 | -0.072222 | -0.038889 | 0.572222 | 1.0 | 1.0 | 0.0 |
| 4 | -0.058333 | -0.058333 | 0.025000 | -0.375000 | 0.0 | 0.0 | 0.0 |
| 5 | 0.030556 | -0.052778 | 0.008333 | -0.452778 | 0.0 | 0.0 | 0.0 |
| 6 | -0.038889 | -0.005556 | -0.011111 | -0.411111 | 0.0 | 0.0 | 0.0 |
| 7 | 0.002778 | 0.008333 | -0.008333 | -0.469444 | 0.0 | 0.0 | 0.0 |
| 8 | -0.005556 | -0.022222 | -0.027778 | -0.411111 | 0.0 | 0.0 | 0.0 |
| 9 | -0.011111 | -0.022222 | 0.022222 | -0.455556 | 0.0 | 0.0 | 0.0 |
| 10 | -0.013889 | 0.002778 | 0.008333 | 0.536111 | 1.0 | 1.0 | 0.0 |
| 11 | 0.016667 | -0.011111 | -0.011111 | 0.538889 | 1.0 | 1.0 | 0.0 |
| 12 | -0.219444 | -0.058333 | -0.008333 | -0.180556 | 0.0 | 0.0 | 0.0 |
| 13 | 0.052778 | 0.075000 | -0.013889 | 0.419444 | 1.0 | 1.0 | 0.0 |
| 14 | 0.036111 | 0.069444 | 0.013889 | 0.413889 | 1.0 | 1.0 | 0.0 |
| 15 | -0.041667 | 0.080556 | 0.013889 | 0.480556 | 1.0 | 1.0 | 0.0 |
| 16 | -0.025000 | 0.008333 | 0.013889 | 0.536111 | 1.0 | 1.0 | 0.0 |
| 17 | 0.097222 | 0.247222 | 0.052778 | 0.136111 | 1.0 | 1.0 | 0.0 |
| 18 | 0.025000 | -0.008333 | 0.008333 | 0.508333 | 1.0 | 1.0 | 0.0 |
| 19 | -0.013889 | 0.052778 | -0.013889 | 0.508333 | 1.0 | 1.0 | 0.0 |
| 20 | -0.016667 | 0.005556 | -0.022222 | -0.433333 | 0.0 | 0.0 | 0.0 |
| 21 | 0.005556 | 0.022222 | 0.005556 | 0.500000 | 1.0 | 1.0 | 0.0 |
| 22 | 0.013889 | -0.075000 | -0.002778 | -0.402778 | 0.0 | 0.0 | 0.0 |
| 23 | -0.288889 | -0.238889 | -0.183333 | 0.244444 | 1.0 | 0.0 | 1.0 |
| 24 | 0.008333 | 0.063889 | 0.013889 | 0.447222 | 1.0 | 1.0 | 0.0 |
| 25 | 0.013889 | 0.030556 | 0.008333 | -0.519444 | 0.0 | 0.0 | 0.0 |
| 26 | -0.041667 | -0.030556 | 0.002778 | -0.397222 | 0.0 | 0.0 | 0.0 |
| 27 | -0.022222 | -0.088889 | 0.005556 | -0.361111 | 0.0 | 0.0 | 0.0 |
| 28 | 0.072222 | 0.133333 | 0.050000 | 0.277778 | 1.0 | 1.0 | 0.0 |
| 29 | 0.030556 | -0.063889 | 0.030556 | -0.463889 | 0.0 | 0.0 | 0.0 |
| Most chosen | |
|---|---|
| D | 26 |
| A | 3 |
| B | 1 |
| C | 0 |
| Most weighted | |
|---|---|
| D | 12.250000 |
| B | 1.911111 |
| A | 1.622222 |
| C | 0.683333 |